import torch
import torchvision
import torch.nn as nn

# Select the compute device once, then report which backend was chosen.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Working on GPU" if device.type == "cuda" else "Working on CPU")

class CNN(nn.Module):
    """Two-stage convolutional classifier with a small fully connected head.

    Each convolutional stage is ``Conv2d(3x3, stride 1, padding 1) ->
    BatchNorm2d -> ReLU -> Dropout(0.2) -> MaxPool2d(2)``: the convolution
    keeps the spatial size and the pooling step halves it. The head flattens
    the 64-channel feature map, projects it to 128 features and then to the
    class scores.

    The first linear layer takes ``64 * 7 * 7`` inputs, i.e. it expects a
    7x7 feature map after the two pooling steps (what a 28x28 image yields).

    Args:
        in_channels: Number of channels of the input images (1 by default;
            pass 3 for RGB input).
        num_classes: Number of output scores (10 by default).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        # NOTE: attribute names and construction order are kept stable so
        # that state_dict keys and seeded weight initialisation do not change.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool2d(2),
        )
        # Fully connected head: flattened 64x7x7 feature map -> 128 hidden
        # features -> one score per class.
        self.fc = nn.Linear(64 * 7 * 7, 128)
        self.out = nn.Linear(128, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch of shape ``(batch_size, in_channels, H, W)`` whose
                spatial size reduces to 7x7 after the two pooling steps.

        Returns:
            A tuple ``(output, features)``: ``output`` holds the raw class
            scores of shape ``(batch_size, num_classes)`` and ``features``
            holds the 128-dimensional post-ReLU hidden activations, returned
            alongside the scores for visualization.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the output of conv2 to (batch_size, 64 * 7 * 7).
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc(x))
        output = self.out(x)
        return output, x  # x is returned for visualization

def CreateCNN():
    """Build a seeded CNN on the selected device, wrapped in DataParallel.

    The global torch RNG is seeded with 1 before the model is constructed,
    so the initial weights are the same on every call.

    Returns:
        The ``nn.DataParallel`` wrapper around the CNN; the underlying
        model is reachable through its ``module`` attribute.
    """
    torch.manual_seed(1)
    model = CNN().to(device)
    return nn.DataParallel(model)